import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


# Compact CNN classifier for 28x28 single-channel images.
class SimpleCNN(nn.Module):
    """Minimal CNN: one 3x3 convolution followed by a single linear classifier.

    Input: (batch, 1, 28, 28) tensors. Output: (batch, 10) class logits.
    """

    def __init__(self):
        super().__init__()
        # Valid 3x3 convolution shrinks 28x28 -> 26x26, producing 8 feature maps.
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)
        # Flattened 8 * 26 * 26 features map directly to 10 class logits.
        self.fc1 = nn.Linear(8 * 26 * 26, 10)

    def forward(self, x):
        features = F.relu(self.conv1(x))
        flattened = features.view(features.size(0), -1)
        return self.fc1(flattened)


# Dataset wrapper supporting both labeled and unlabeled tensors.
class CustomDataset(Dataset):
    """Wraps data tensors as a Dataset.

    When no labels are supplied, each sample is paired with the sentinel
    label -1 so labeled and unlabeled loaders yield the same tuple shape.
    """

    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        target = self.labels[idx] if self.labels is not None else -1
        return self.data[idx], target


# Semi-supervised training with pseudo labels (one epoch).
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):
    """Run one epoch of pseudo-label semi-supervised training.

    Iterates labeled and unlabeled batches in lockstep (``zip`` stops at the
    shorter loader). For each step the supervised cross-entropy loss and the
    pseudo-label loss (on unlabeled samples whose max softmax probability
    exceeds ``threshold``) are summed and applied in a single backward pass.

    Args:
        model: classifier returning per-class logits.
        labeled_loader: yields (data, labels) batches.
        unlabeled_loader: yields (data, _) batches; labels are ignored.
        optimizer: optimizer over ``model.parameters()``.
        device: torch device to move batches to.
        threshold: confidence cutoff for accepting a pseudo label.

    Returns:
        (mean labeled loss, mean pseudo loss), averaged over the batches
        actually processed. The pseudo term is 0.0 if no sample ever passed
        the confidence threshold.
    """
    model.train()
    labeled_loss_total = 0.0
    pseudo_loss_total = 0.0
    num_batches = 0

    for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):
        labeled_data = labeled_data.to(device)
        labeled_labels = labeled_labels.to(device)
        unlabeled_data = unlabeled_data.to(device)

        optimizer.zero_grad()

        # 1. Supervised loss on the labeled batch.
        labeled_output = model(labeled_data)
        labeled_loss = F.cross_entropy(labeled_output, labeled_labels)

        # 2. Generate pseudo labels without tracking gradients — the targets
        #    must be treated as constants, and this avoids building an unused
        #    autograd graph over the full unlabeled batch.
        with torch.no_grad():
            probs = F.softmax(model(unlabeled_data), dim=1)
            max_probs, pseudo_labels = torch.max(probs, dim=1)
            pseudo_mask = max_probs > threshold  # keep only confident predictions

        # 3. Pseudo-label loss on confident samples; summed with the labeled
        #    loss so ONE backward/step applies both signals. (Calling
        #    zero_grad() between the two backwards would discard the
        #    supervised gradients before the optimizer step.)
        pseudo_loss = None
        if pseudo_mask.any():
            pseudo_output = model(unlabeled_data[pseudo_mask])
            pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels[pseudo_mask])
            total_loss = labeled_loss + pseudo_loss
        else:
            total_loss = labeled_loss

        total_loss.backward()
        optimizer.step()

        labeled_loss_total += labeled_loss.item()
        if pseudo_loss is not None:
            pseudo_loss_total += pseudo_loss.item()
        num_batches += 1

    # Average over batches actually run: zip truncates to the shorter loader,
    # so dividing by len(unlabeled_loader) would understate the pseudo loss.
    num_batches = max(num_batches, 1)
    return labeled_loss_total / num_batches, pseudo_loss_total / num_batches


# Synthetic data: random "images" standing in for a real 28x28 grayscale set.
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28)  # channels x height x width
num_classes = 10
batch_size = 32  # small batches keep the demo light

labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)

labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)

labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer, and device setup.
device = torch.device("cpu")  # CPU for portability of the demo
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train for a fixed number of epochs, recording both loss curves.
num_epochs = 10
labeled_losses = []
pseudo_losses = []

for epoch in range(num_epochs):
    labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)
    labeled_losses.append(labeled_loss)
    pseudo_losses.append(pseudo_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")

# Plot both loss curves over the training run.
epoch_axis = range(num_epochs)
plt.plot(epoch_axis, labeled_losses, label='Labeled Loss')
plt.plot(epoch_axis, pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()

# Inspect pseudo-label quality: predict labels for a few unlabeled samples
# and display them alongside the images.
model.eval()
with torch.no_grad():
    sample_unlabeled_data = unlabeled_data[:10].to(device)
    probs = F.softmax(model(sample_unlabeled_data), dim=1)
    predicted_labels = probs.argmax(dim=1)

    # Print the predicted pseudo labels.
    print("Generated Pseudo Labels for Samples:")
    print(predicted_labels)

    # Show each sample with its predicted label as the title.
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        # Drop the channel dimension and convert to NumPy for imshow.
        img = sample_unlabeled_data[i].cpu().numpy().squeeze()
        ax.imshow(img, cmap='gray')
        ax.set_title(f"Pred: {predicted_labels[i].item()}")
        ax.axis('off')
    plt.show()
